/**
* Copyright 2012 Nikita Koksharov
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.corundumstudio.socketio.transport;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.corundumstudio.socketio.Configuration;
import com.corundumstudio.socketio.SocketIOChannelInitializer;
import com.corundumstudio.socketio.Transport;
import com.corundumstudio.socketio.handler.AuthorizeHandler;
import com.corundumstudio.socketio.handler.ClientHead;
import com.corundumstudio.socketio.handler.ClientsBox;
import com.corundumstudio.socketio.messages.PacketsMessage;
import com.corundumstudio.socketio.scheduler.CancelableScheduler;
import com.corundumstudio.socketio.scheduler.SchedulerKey;
import io.netty.buffer.ByteBufHolder;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler.Sharable;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.QueryStringDecoder;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketFrameAggregator;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory;
import io.netty.util.ReferenceCountUtil;
/**
 * Netty inbound handler for the Socket.IO {@code websocket} transport.
 *
 * <p>It performs the websocket handshake for HTTP requests whose query string carries
 * {@code transport=websocket}, converts incoming text/binary websocket frames into
 * {@link PacketsMessage}s for the bound client, and closes the channel when a close frame
 * arrives. The handler is {@link Sharable}; its only fields are the final collaborators
 * passed to the constructor.
 */
@Sharable
public class WebSocketTransport extends ChannelInboundHandlerAdapter {

    /** Value of the {@code transport} query parameter that selects this transport. */
    public static final String NAME = "websocket";

    private static final Logger log = LoggerFactory.getLogger(WebSocketTransport.class);

    private final AuthorizeHandler authorizeHandler;
    private final CancelableScheduler scheduler;
    private final Configuration configuration;
    private final ClientsBox clientsBox;

    /** Selects the {@code wss://} scheme (instead of {@code ws://}) for the handshake location. */
    private final boolean isSsl;

    /**
     * Creates the transport handler.
     *
     * @param isSsl            whether the websocket location is advertised as {@code wss://}
     * @param authorizeHandler notified via {@code connect(client)} once a client is bound
     * @param configuration    supplies enabled transports, max frame payload and upgrade timeout
     * @param scheduler        used to schedule the upgrade timeout for polling clients
     * @param clientsBox       lookup of clients by channel and by session id
     */
    public WebSocketTransport(boolean isSsl,
            AuthorizeHandler authorizeHandler, Configuration configuration,
            CancelableScheduler scheduler, ClientsBox clientsBox) {
        this.isSsl = isSsl;
        this.authorizeHandler = authorizeHandler;
        this.configuration = configuration;
        this.scheduler = scheduler;
        this.clientsBox = clientsBox;
    }

    /**
     * Dispatches inbound messages.
     *
     * <ul>
     *   <li>Close frames close the channel.</li>
     *   <li>Text and binary frames are forwarded as a {@link PacketsMessage}.</li>
     *   <li>A {@link FullHttpRequest} with {@code transport=websocket} triggers the handshake
     *       and is released here.</li>
     *   <li>Everything else is passed on unchanged.</li>
     * </ul>
     */
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        if (msg instanceof CloseWebSocketFrame) {
            ctx.channel().close();
            ReferenceCountUtil.release(msg);
        } else if (msg instanceof BinaryWebSocketFrame
                || msg instanceof TextWebSocketFrame) {
            handleDataFrame(ctx, (ByteBufHolder) msg);
        } else if (msg instanceof FullHttpRequest) {
            FullHttpRequest req = (FullHttpRequest) msg;
            QueryStringDecoder queryDecoder = new QueryStringDecoder(req.uri());
            List<String> transport = queryDecoder.parameters().get("transport");

            if (transport != null && NAME.equals(transport.get(0))) {
                try {
                    handleUpgradeRequest(ctx, req, queryDecoder);
                } finally {
                    // This handler consumes the request; it is never passed further down.
                    req.release();
                }
            } else {
                ctx.fireChannelRead(msg);
            }
        } else {
            ctx.fireChannelRead(msg);
        }
    }

    /** Forwards a text/binary frame's payload to the channel's client; always releases the frame once. */
    private void handleDataFrame(ChannelHandlerContext ctx, ByteBufHolder frame) {
        try {
            ClientHead client = clientsBox.get(ctx.channel());
            if (client == null) {
                log.debug("Client was already disconnected. Channel closed!");
                ctx.channel().close();
                return;
            }
            ctx.pipeline().fireChannelRead(new PacketsMessage(client, frame.content(), Transport.WEBSOCKET));
        } finally {
            // Released even if a downstream handler throws, so the frame cannot leak.
            frame.release();
        }
    }

    /** Validates a websocket upgrade request and starts the handshake; closes the channel if it cannot proceed. */
    private void handleUpgradeRequest(ChannelHandlerContext ctx, FullHttpRequest req,
            QueryStringDecoder queryDecoder) {
        if (!configuration.getTransports().contains(Transport.WEBSOCKET)) {
            log.debug("{} transport not supported by configuration.", Transport.WEBSOCKET);
            ctx.channel().close();
            return;
        }

        UUID sessionId = resolveSessionId(ctx.channel(), queryDecoder.parameters().get("sid"));
        if (sessionId == null) {
            ctx.channel().close();
            return;
        }
        handshake(ctx, sessionId, queryDecoder.path(), req);
    }

    /**
     * Determines the session id for an upgrade request.
     *
     * @return the id parsed from {@code sid}, or otherwise the id of the client stored in the
     *         channel's {@link ClientHead#CLIENT} attribute; {@code null} if neither is usable
     */
    private UUID resolveSessionId(Channel channel, List<String> sid) {
        if (sid != null && sid.get(0) != null) {
            try {
                return UUID.fromString(sid.get(0));
            } catch (IllegalArgumentException e) {
                // sid comes straight from the query string, so malformed values are expected input.
                log.warn("Malformed sid from ip: {}. Channel closed!", channel.remoteAddress());
                return null;
            }
        }

        // First connection: no sid in the query, fall back to the client attached to the channel.
        // NOTE(review): presumably set by an earlier handler in the pipeline — verify.
        ClientHead client = channel.attr(ClientHead.CLIENT).get();
        if (client == null) {
            log.warn("No sid and no client bound to channel with ip: {}. Channel closed!",
                    channel.remoteAddress());
            return null;
        }
        return client.getSessionId();
    }

    /** Flushes only for channels that are the client's websocket transport channel. */
    @Override
    public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
        ClientHead client = clientsBox.get(ctx.channel());
        if (client != null && client.isTransportChannel(ctx.channel(), Transport.WEBSOCKET)) {
            ctx.flush();
        } else {
            super.channelReadComplete(ctx);
        }
    }

    /** Notifies the client of a disconnect when its websocket transport channel goes inactive. */
    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        ClientHead client = clientsBox.get(ctx.channel());
        if (client != null && client.isTransportChannel(ctx.channel(), Transport.WEBSOCKET)) {
            log.debug("channel inactive {}", client.getSessionId());
            client.onChannelDisconnect();
        }
        super.channelInactive(ctx);
    }

    /**
     * Performs the websocket handshake for {@code req}.
     *
     * <p>On success, a {@link WebSocketFrameAggregator} is inserted before this handler and the
     * client is connected. If no handshaker matches the requested version, an "unsupported
     * version" response is sent.
     *
     * @param path currently unused; kept for signature compatibility
     */
    private void handshake(ChannelHandlerContext ctx, final UUID sessionId, String path, FullHttpRequest req) {
        final Channel channel = ctx.channel();

        WebSocketServerHandshakerFactory factory =
                new WebSocketServerHandshakerFactory(getWebSocketLocation(req), null, true, configuration.getMaxFramePayloadLength());
        WebSocketServerHandshaker handshaker = factory.newHandshaker(req);
        if (handshaker != null) {
            ChannelFuture f = handshaker.handshake(channel, req);
            f.addListener(new ChannelFutureListener() {
                @Override
                public void operationComplete(ChannelFuture future) throws Exception {
                    if (!future.isSuccess()) {
                        // The trailing throwable is logged as the exception by SLF4J.
                        log.error("Can't handshake {}", sessionId, future.cause());
                        return;
                    }

                    channel.pipeline().addBefore(SocketIOChannelInitializer.WEB_SOCKET_TRANSPORT, SocketIOChannelInitializer.WEB_SOCKET_AGGREGATOR,
                            new WebSocketFrameAggregator(configuration.getMaxFramePayloadLength()));
                    connectClient(channel, sessionId);
                }
            });
        } else {
            WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
        }
    }

    /**
     * Binds {@code channel} as the websocket transport of the client with {@code sessionId}.
     *
     * <p>Unknown session ids close the channel. If the client is still on polling, an upgrade
     * timeout is scheduled that calls {@code onChannelDisconnect()} on the client if it is still
     * registered when the timeout fires.
     * NOTE(review): presumably the timeout is cancelled elsewhere once the upgrade completes — verify.
     */
    private void connectClient(final Channel channel, final UUID sessionId) {
        ClientHead client = clientsBox.get(sessionId);
        if (client == null) {
            log.warn("Unauthorized client with sessionId: {} with ip: {}. Channel closed!",
                    sessionId, channel.remoteAddress());
            channel.close();
            return;
        }

        client.bindChannel(channel, Transport.WEBSOCKET);

        authorizeHandler.connect(client);

        if (client.getCurrentTransport() == Transport.POLLING) {
            SchedulerKey key = new SchedulerKey(SchedulerKey.Type.UPGRADE_TIMEOUT, sessionId);
            scheduler.schedule(key, new Runnable() {
                @Override
                public void run() {
                    ClientHead clientHead = clientsBox.get(sessionId);
                    if (clientHead != null) {
                        if (log.isDebugEnabled()) {
                            log.debug("client did not complete upgrade - closing transport");
                        }
                        clientHead.onChannelDisconnect();
                    }
                }
            }, configuration.getUpgradeTimeout(), TimeUnit.MILLISECONDS);
        }

        log.debug("client {} handshake completed", sessionId);
    }

    /** Builds the handshake location from the scheme, the request's Host header and its URI. */
    private String getWebSocketLocation(HttpRequest req) {
        String protocol = isSsl ? "wss://" : "ws://";
        return protocol + req.headers().get(HttpHeaderNames.HOST) + req.uri();
    }
}